"""
Visualization module for Demographics & Capital Flows project.

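Generates:
1. Age-group coefficient plots (replicating Koomen & Wicht (2020), Figure 2)
2. Country demographic-contribution time series (Koomen & Wicht (2020), Figure 3)
3. Model comparison charts (R² by specification)
4. Projection plots
5. Residual-by-country and China counterfactual plots
6. Regression result tables (CSV and LaTeX)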
"""

import pandas as pd
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from pathlib import Path

# Root of all generated artefacts; figures and tables live in sub-folders.
OUTPUT_DIR = Path("/mnt/c/demographics_capital_flows/multilateral/output")
FIG_DIR, TAB_DIR = OUTPUT_DIR / "figures", OUTPUT_DIR / "tables"

# Create both output folders at import time (no-op when they already exist).
for _out_dir in (FIG_DIR, TAB_DIR):
    _out_dir.mkdir(parents=True, exist_ok=True)
del _out_dir

# Five-year age bins 0-4 ... 75-79 followed by an open-ended 80+ bin (17 groups).
AGE_LABELS = [f'{lo}-{lo + 4}' for lo in range(0, 80, 5)] + ['80+']


def plot_age_coefficients(alpha, alpha_se=None, title="Implied Age-Group Coefficients",
                          filename="age_coefficients.png", labels=None):
    """
    Plot implied age-group coefficients α_g with 95% confidence intervals.

    Replicates Koomen & Wicht (2020) Figure 2.

    Expected shape: negative for young dependents, positive hump peaking
    around 55-59 (prime savers), negative again for elderly.

    Parameters
    ----------
    alpha : array-like
        Point estimates, one per age group.
    alpha_se : array-like, optional
        Standard errors aligned with ``alpha``. When given, error bars span
        ±1.96 · SE.
    title : str
        Axes title.
    filename : str
        Output file name, written under ``FIG_DIR``.
    labels : sequence of str, optional
        Tick labels, one per coefficient. Defaults to ``AGE_LABELS`` when
        ``alpha`` has exactly that many entries. Otherwise defaults to the
        1-based group number.

    Raises
    ------
    ValueError
        If ``labels`` or ``alpha_se`` does not match the length of ``alpha``.
    """
    # Coerce to float arrays so plain lists work (``1.96 * list`` would raise).
    alpha = np.asarray(alpha, dtype=float).ravel()
    n_groups = alpha.size
    x = np.arange(n_groups)

    if labels is None:
        # AGE_LABELS only fits the default 17-group grid; fall back to group
        # numbers rather than let matplotlib reject a label/tick mismatch.
        labels = AGE_LABELS if n_groups == len(AGE_LABELS) else [str(i + 1) for i in x]
    elif len(labels) != n_groups:
        raise ValueError(f"got {len(labels)} labels for {n_groups} coefficients")

    if alpha_se is not None:
        alpha_se = np.asarray(alpha_se, dtype=float).ravel()
        if alpha_se.size != n_groups:
            raise ValueError(
                f"got {alpha_se.size} standard errors for {n_groups} coefficients")

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.bar(x, alpha, color='steelblue', alpha=0.7, edgecolor='navy', linewidth=0.5)

    if alpha_se is not None:
        ax.errorbar(x, alpha, yerr=1.96 * alpha_se, fmt='none', color='black',
                    capsize=3, linewidth=1)

    ax.axhline(y=0, color='black', linewidth=0.8, linestyle='-')
    ax.set_xticks(x)
    ax.set_xticklabels(labels, rotation=45, ha='right', fontsize=9)
    ax.set_xlabel('Age Group', fontsize=11)
    ax.set_ylabel('Coefficient (effect on CA/GDP, pp)', fontsize=11)
    ax.set_title(title, fontsize=13)
    ax.grid(axis='y', alpha=0.3)

    fig.tight_layout()
    fig.savefig(FIG_DIR / filename, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"  Saved {filename}")


def plot_country_demographic_contributions(profiles, countries=None,
                                           filename="demographic_contributions.png"):
    """
    Plot demographic contribution to CA/GDP for selected countries.

    Replicates Koomen & Wicht (2020) Figure 3. Draws a grid with up to three
    panels per row and one panel per country.

    Parameters
    ----------
    profiles : mapping
        Keyed by country identifier, which is used as the panel title. Each
        value is a DataFrame with ``year`` and ``demo_contribution`` columns.
        It may also have a ``ca_gdp`` column (actual CA/GDP, drawn dashed).
    countries : iterable, optional
        Countries to plot, in panel order. Defaults to the first nine keys of
        ``profiles``. A country missing from ``profiles`` leaves a hidden panel.
    filename : str
        Output file name, written under ``FIG_DIR``.

    Nothing is written when there are no countries to plot.
    """
    if countries is None:
        countries = list(profiles.keys())[:9]
    countries = list(countries)
    if not countries:
        # min(3, 0) columns would otherwise cause a division by zero below.
        print(f"  Skipped {filename}: no countries to plot")
        return

    n_countries = len(countries)
    n_cols = min(3, n_countries)
    n_rows = (n_countries + n_cols - 1) // n_cols

    # squeeze=False always returns a 2-D array of Axes, whatever the grid shape.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 4 * n_rows),
                             sharex=False, sharey=False, squeeze=False)
    flat_axes = axes.ravel()  # row-major, matching the panel order

    for ax, country in zip(flat_axes, countries):
        if country not in profiles:
            ax.set_visible(False)
            continue

        cdf = profiles[country].sort_values('year')

        # Plot demographic contribution
        ax.plot(cdf['year'], cdf['demo_contribution'], 'b-', linewidth=2,
                label='Demo. contribution')

        # Plot actual CA/GDP if available
        if 'ca_gdp' in cdf.columns:
            ax.plot(cdf['year'], cdf['ca_gdp'], 'k--', linewidth=1, alpha=0.5,
                    label='Actual CA/GDP')

        ax.axhline(y=0, color='gray', linewidth=0.5, linestyle='-')
        ax.set_title(country, fontsize=12, fontweight='bold')
        ax.set_xlabel('Year', fontsize=9)
        ax.set_ylabel('CA/GDP (pp)', fontsize=9)
        ax.grid(alpha=0.3)
        ax.legend(fontsize=7)

    # Hide the unused trailing panels of the last row.
    for ax in flat_axes[n_countries:]:
        ax.set_visible(False)

    fig.suptitle('Demographic Contribution to Current Account / GDP',
                 fontsize=14, y=1.02)
    fig.tight_layout()
    fig.savefig(FIG_DIR / filename, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"  Saved {filename}")


def plot_projections(projection_df, countries=None, filename="projections.png",
                     base_year=2024):
    """
    Plot projected demographic contributions for key countries on one axis.

    Parameters
    ----------
    projection_df : pandas.DataFrame
        Long-format frame with ``iso3``, ``year`` and ``demo_contribution``
        columns.
    countries : sequence of str, optional
        ISO3 codes to draw. Defaults to eight large economies. Codes with no
        rows in ``projection_df`` are skipped.
    filename : str
        Output file name, written under ``FIG_DIR``.
    base_year : int or None
        Year marked by a dashed vertical line. Presumably the last year before
        the projection horizon; defaults to 2024, the previously hard-coded
        value. ``None`` omits the line.
    """
    if countries is None:
        countries = ['CHN', 'IND', 'IDN', 'JPN', 'USA', 'NGA', 'DEU', 'BRA']

    fig, ax = plt.subplots(figsize=(12, 7))

    colors = plt.cm.tab10(np.linspace(0, 1, len(countries)))

    for color, country in zip(colors, countries):
        cdf = projection_df[projection_df['iso3'] == country].sort_values('year')
        if cdf.empty:
            continue
        ax.plot(cdf['year'], cdf['demo_contribution'],
                color=color, linewidth=2, label=country)

    ax.axhline(y=0, color='black', linewidth=0.8)
    if base_year is not None:
        ax.axvline(x=base_year, color='gray', linewidth=0.8, linestyle='--', alpha=0.5)
    ax.set_xlabel('Year', fontsize=12)
    ax.set_ylabel('Demographic Contribution to CA/GDP (pp)', fontsize=12)
    ax.set_title('Projected Demographic Pressure on Current Accounts', fontsize=14)

    # Only add a legend when something was plotted. With no labelled artists,
    # matplotlib warns and draws an empty box.
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend(loc='best', fontsize=10)
    ax.grid(alpha=0.3)

    fig.tight_layout()
    fig.savefig(FIG_DIR / filename, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"  Saved {filename}")


def plot_model_comparison(comparison_df, filename="model_comparison.png"):
    """
    Bar chart comparing R² across model specifications.

    Parameters
    ----------
    comparison_df : pandas.DataFrame
        One row per specification, with a ``Model`` label column and an
        ``R²`` column.
    filename : str
        Output file name, written under ``FIG_DIR``.

    Axis and labelling rules:

    * The x-axis always includes zero and extends 15% past the extreme finite
      values, so negative R² stays visible.
    * Non-finite R² values get no text label.
    * Nothing is written for an empty frame.
    """
    if len(comparison_df) == 0:
        print(f"  Skipped {filename}: no models to compare")
        return

    models = comparison_df['Model']
    r2 = comparison_df['R²'].astype(float)

    fig, ax = plt.subplots(figsize=(8, 5))
    bars = ax.barh(models, r2, color='steelblue', edgecolor='navy', alpha=0.7)
    ax.set_xlabel('R²', fontsize=12)
    ax.set_title('Model Comparison: R-squared', fontsize=13)

    for bar, val in zip(bars, r2):
        if not np.isfinite(val):
            continue
        # Place the label just past the end of the bar, on the side it points to.
        offset, ha = (0.005, 'left') if val >= 0 else (-0.005, 'right')
        ax.text(val + offset, bar.get_y() + bar.get_height() / 2,
                f'{val:.3f}', va='center', ha=ha, fontsize=10)

    # Limits from finite values only: NaN/inf would make set_xlim raise, and
    # an all-non-positive R² would otherwise give a degenerate (0, <=0) range.
    finite = r2[np.isfinite(r2)]
    low = min(0.0, finite.min() * 1.15) if len(finite) else 0.0
    high = max(0.0, finite.max() * 1.15) if len(finite) else 0.0
    if high <= low:
        high = low + 1.0
    ax.set_xlim(low, high)
    ax.grid(axis='x', alpha=0.3)

    fig.tight_layout()
    fig.savefig(FIG_DIR / filename, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"  Saved {filename}")


def _largest_abs_residuals(country_resid_df, max_countries):
    """Return up to ``max_countries`` rows with the largest |mean_resid|, sorted by signed value."""
    # Ranking by absolute value keeps both large positive and large negative
    # residuals. An ascending sort followed by head() keeps only the most negative.
    return (country_resid_df
            .sort_values('mean_resid', key=lambda s: s.abs(), ascending=False)
            .head(max_countries)
            .sort_values('mean_resid'))


def plot_residual_map(country_resid_df, filename="residual_map.png", max_countries=40):
    """
    Plot average model residuals by country (horizontal bar chart).

    Parameters
    ----------
    country_resid_df : pandas.DataFrame
        One row per country, with ``iso3`` and ``mean_resid`` columns.
    filename : str
        Output file name, written under ``FIG_DIR``.
    max_countries : int
        Maximum number of bars. When there are more countries, the ones with
        the largest absolute mean residual are shown. Bars are ordered by
        signed residual; positive bars are green, others red.
    """
    df = _largest_abs_residuals(country_resid_df, max_countries)

    fig, ax = plt.subplots(figsize=(10, 12))

    colors = ['green' if x > 0 else 'red' for x in df['mean_resid']]
    ax.barh(df['iso3'], df['mean_resid'], color=colors, alpha=0.7)
    ax.axvline(x=0, color='black', linewidth=0.8)
    ax.set_xlabel('Mean Residual (CA/GDP pp)', fontsize=12)
    ax.set_title('Model Residuals by Country\n(Positive = Higher CA than predicted)', fontsize=13)
    ax.grid(axis='x', alpha=0.3)

    fig.tight_layout()
    fig.savefig(FIG_DIR / filename, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"  Saved {filename}")


def plot_china_counterfactual(actual_profile, counterfactual_profile,
                              filename="china_counterfactual.png"):
    """
    Compare China's demographic contribution to CA/GDP under two scenarios.

    The actual path is drawn as a solid red line and the replacement-fertility
    counterfactual as a dashed blue line. A scenario that is ``None`` or has
    no rows is skipped. A dotted vertical line marks 1980, labelled as the
    start of the one-child policy. The figure is written to ``FIG_DIR / filename``.
    """
    scenarios = (
        (actual_profile, 'r-', 'Actual (one-child policy)'),
        (counterfactual_profile, 'b--', 'Counterfactual (replacement fertility)'),
    )

    fig, ax = plt.subplots(figsize=(10, 6))

    for profile, line_fmt, legend_label in scenarios:
        if profile is None or len(profile) == 0:
            continue
        ax.plot(profile['year'], profile['demo_contribution'],
                line_fmt, linewidth=2, label=legend_label)

    ax.axhline(y=0, color='black', linewidth=0.8)
    ax.axvline(x=1980, color='gray', linewidth=0.8, linestyle=':', alpha=0.5,
               label='One-child policy start')

    ax.set_xlabel('Year', fontsize=12)
    ax.set_ylabel('Demographic Contribution to CA/GDP (pp)', fontsize=12)
    ax.set_title('China: Actual vs. Counterfactual Demographic Pressure', fontsize=14)
    ax.legend(fontsize=10)
    ax.grid(alpha=0.3)

    fig.tight_layout()
    fig.savefig(FIG_DIR / filename, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"  Saved {filename}")


def save_regression_table(model, feature_names, filename="regression_results.csv"):
    """
    Write the model's coefficient table to ``TAB_DIR / filename`` as CSV.

    The table comes from ``model.to_dataframe(feature_names)`` and is written
    without its index. The same DataFrame is returned for further use.
    """
    results = model.to_dataframe(feature_names)
    csv_path = TAB_DIR / filename
    results.to_csv(csv_path, index=False)
    print(f"  Saved {filename}")
    return results


# Text-mode LaTeX escapes for characters that are special to TeX.
_LATEX_SPECIALS = {
    '\\': r'\textbackslash{}',
    '&': r'\&',
    '%': r'\%',
    '$': r'\$',
    '#': r'\#',
    '_': r'\_',
    '{': r'\{',
    '}': r'\}',
    '~': r'\textasciitilde{}',
    '^': r'\textasciicircum{}',
}


def _latex_escape(text):
    """Return ``str(text)`` with LaTeX special characters escaped for text mode."""
    # One pass over the characters, so inserted backslashes are never re-escaped.
    return ''.join(_LATEX_SPECIALS.get(ch, ch) for ch in str(text))


def _significance_stars(p_value):
    """Return the star superscript for a p-value: *** <0.01, ** <0.05, * <0.1, else ''."""
    # NaN fails every comparison and therefore gets no stars.
    if p_value < 0.01:
        return r'$^{***}$'
    if p_value < 0.05:
        return r'$^{**}$'
    if p_value < 0.1:
        return r'$^{*}$'
    return ''


def save_latex_table(model, feature_names, filename="regression_results.tex",
                     escape_names=True):
    """
    Save regression results as a LaTeX ``table`` environment.

    Writes one row per variable: coefficient with significance stars,
    standard error, t-statistic and p-value. These are followed by summary
    rows for observation count, country count, R² and the AR(1) coefficient ρ.

    Parameters
    ----------
    model : object
        Must provide ``to_dataframe(feature_names)``, returning columns
        ``variable``, ``coefficient``, ``std_error``, ``t_statistic`` and
        ``p_value``. Must also have attributes ``n_obs``, ``n_countries``,
        ``r_squared`` and ``rho``.
    feature_names : sequence
        Passed through to ``model.to_dataframe``.
    filename : str
        Output file name, written under ``TAB_DIR`` as UTF-8.
    escape_names : bool
        Escape LaTeX special characters (``_``, ``&``, ``%``, ...) in variable
        names. Raw underscores in names such as ``pop_share_0_4`` break LaTeX
        compilation. Pass False if the names already contain LaTeX markup.
    """
    df = model.to_dataframe(feature_names)

    # Format for LaTeX
    lines = [
        r"\begin{table}[htbp]",
        r"\centering",
        r"\caption{Current Account / GDP Regression Results}",
        r"\begin{tabular}{lcccc}",
        r"\hline\hline",
        r"Variable & Coefficient & Std. Error & t-statistic & p-value \\",
        r"\hline",
    ]

    for _, row in df.iterrows():
        name = _latex_escape(row['variable']) if escape_names else row['variable']
        stars = _significance_stars(row['p_value'])
        lines.append(
            f"{name} & {row['coefficient']:.4f}{stars} & "
            f"{row['std_error']:.4f} & {row['t_statistic']:.2f} & "
            f"{row['p_value']:.4f} \\\\"
        )

    # Summary rows span the four numeric columns.
    lines.extend([
        r"\hline",
        f"N obs & \\multicolumn{{4}}{{c}}{{{model.n_obs:,}}} \\\\",
        f"N countries & \\multicolumn{{4}}{{c}}{{{model.n_countries}}} \\\\",
        f"$R^2$ & \\multicolumn{{4}}{{c}}{{{model.r_squared:.4f}}} \\\\",
        f"$\\rho$ (AR1) & \\multicolumn{{4}}{{c}}{{{model.rho:.4f}}} \\\\",
        r"\hline\hline",
        r"\end{tabular}",
        r"\end{table}",
    ])

    # Use an explicit encoding for portability, and end with a newline so the
    # file is well-formed text when \input into a document.
    (TAB_DIR / filename).write_text('\n'.join(lines) + '\n', encoding='utf-8')
    print(f"  Saved {filename}")
